/*
 * Copyright 2023-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.vertexai.embedding.text;

import java.util.List;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Integration tests for {@link VertexAiTextEmbeddingModel} against a live Vertex AI
 * endpoint.
 *
 * <p>
 * Enabled only when the {@code VERTEX_AI_GEMINI_PROJECT_ID} and
 * {@code VERTEX_AI_GEMINI_LOCATION} environment variables are set. Available publisher
 * models are listed in the Vertex AI Model Garden:
 * https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/textembedding-gecko
 */
@SpringBootTest(classes = VertexAiTextEmbeddingModelIT.Config.class)
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
class VertexAiTextEmbeddingModelIT {

	/**
	 * Model used by the task-type tests. The same name is sent both through Spring AI
	 * and directly through the Google SDK, so the two embeddings are comparable.
	 */
	private static final String TASK_TYPE_TEST_MODEL = "text-embedding-005";

	/** Embedding dimensionality asserted for the models exercised below. */
	private static final int EXPECTED_DIMENSIONS = 768;

	@Autowired
	private VertexAiTextEmbeddingModel embeddingModel;

	@ParameterizedTest(name = "{0} : {displayName}")
	@ValueSource(strings = { "text-embedding-004", "text-multilingual-embedding-002" })
	void defaultEmbedding(String modelName) {
		assertThat(this.embeddingModel).isNotNull();

		var options = VertexAiTextEmbeddingOptions.builder().model(modelName).build();

		EmbeddingResponse embeddingResponse = this.embeddingModel
			.call(new EmbeddingRequest(List.of("Hello World", "World is Big"), options));

		assertThat(embeddingResponse.getResults()).hasSize(2);
		assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(EXPECTED_DIMENSIONS);
		assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(EXPECTED_DIMENSIONS);
		assertThat(embeddingResponse.getMetadata().getModel()).as("Model name in metadata should match expected model")
			.isEqualTo(modelName);

		assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens())
			.as("Total tokens in metadata should be 5")
			.isEqualTo(5L);

		assertThat(this.embeddingModel.dimensions()).isEqualTo(EXPECTED_DIMENSIONS);
	}

	// Fixing https://github.com/spring-projects/spring-ai/issues/2168
	@Test
	void testTaskTypeProperty() {
		VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
			.model(TASK_TYPE_TEST_MODEL)
			.taskType(VertexAiTextEmbeddingOptions.TaskType.RETRIEVAL_DOCUMENT)
			.build();

		String text = "Test text for embedding";

		// Generate the embedding through Spring AI with an explicit RETRIEVAL_DOCUMENT
		// task type.
		EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options));

		assertThat(embeddingResponse.getResults()).hasSize(1);
		assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotNull();

		float[] springAiEmbedding = embeddingResponse.getResults().get(0).getOutput();

		// Reference embeddings produced by calling the Google SDK directly with each
		// task type. RETRIEVAL_QUERY is what the raw Vertex AI request uses when no
		// task_type is supplied, so it is the value a dropped task type would yield.
		float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT",
				TASK_TYPE_TEST_MODEL);
		float[] googleSdkQueryEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_QUERY", TASK_TYPE_TEST_MODEL);

		// The task type must actually reach the API: Spring AI's result matches the
		// SDK's RETRIEVAL_DOCUMENT embedding ...
		assertThat(springAiEmbedding)
			.as("Spring AI embedding with RETRIEVAL_DOCUMENT should match Google SDK RETRIEVAL_DOCUMENT embedding")
			.isEqualTo(googleSdkDocumentEmbedding);

		// ... and differs from the RETRIEVAL_QUERY embedding.
		assertThat(springAiEmbedding)
			.as("Spring AI embedding with RETRIEVAL_DOCUMENT should NOT match Google SDK RETRIEVAL_QUERY embedding")
			.isNotEqualTo(googleSdkQueryEmbedding);
	}

	// Fixing https://github.com/spring-projects/spring-ai/issues/2168
	@Test
	void testDefaultTaskTypeBehavior() {
		// No task type is set here, so Spring AI's own options default applies.
		VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
			.model(TASK_TYPE_TEST_MODEL)
			.build();

		String text = "Test text for default embedding";

		EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options));

		assertThat(embeddingResponse.getResults()).hasSize(1);

		float[] springAiDefaultEmbedding = embeddingResponse.getResults().get(0).getOutput();

		// Spring AI's documented default task type is RETRIEVAL_DOCUMENT, which differs
		// from the raw Vertex AI request default (RETRIEVAL_QUERY).
		float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT",
				TASK_TYPE_TEST_MODEL);

		assertThat(springAiDefaultEmbedding)
			.as("Default Spring AI embedding should match Google SDK RETRIEVAL_DOCUMENT embedding")
			.isEqualTo(googleSdkDocumentEmbedding);
	}

	/**
	 * Computes a reference embedding by calling the Vertex AI prediction service
	 * directly through the Google SDK, bypassing Spring AI.
	 * @param text the content to embed
	 * @param taskType the value sent as the instance's {@code task_type} field
	 * @param modelName the Google publisher model to query; must match the model used on
	 * the Spring AI side for a meaningful comparison
	 * @return the returned {@code embeddings.values}, narrowed to {@code float}
	 * @throws RuntimeException wrapping any failure from the SDK call
	 */
	private float[] getEmbeddingUsingGoogleSdk(String text, String taskType, String modelName) {
		String project = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID");
		String location = System.getenv("VERTEX_AI_GEMINI_LOCATION");
		String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);

		try {
			PredictionServiceSettings settings = PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();

			EndpointName endpointName = EndpointName.ofProjectLocationPublisherModelName(project, location, "google",
					modelName);

			Value instance = Value.newBuilder()
				.setStructValue(Struct.newBuilder()
					.putFields("content", Value.newBuilder().setStringValue(text).build())
					.putFields("task_type", Value.newBuilder().setStringValue(taskType).build())
					.build())
				.build();

			PredictRequest request = PredictRequest.newBuilder()
				.setEndpoint(endpointName.toString())
				.addInstances(instance)
				.build();

			try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
				Value prediction = client.predict(request).getPredictionsList().get(0);
				Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
				List<Value> values = embeddings.getStructValue()
					.getFieldsOrThrow("values")
					.getListValue()
					.getValuesList();

				// Convert directly into the primitive array; avoids boxing every element.
				float[] floatArray = new float[values.size()];
				for (int i = 0; i < floatArray.length; i++) {
					floatArray[i] = (float) values.get(i).getNumberValue();
				}
				return floatArray;
			}
		}
		catch (Exception e) {
			throw new RuntimeException("Failed to get embedding from Google SDK", e);
		}
	}

	/**
	 * Minimal Spring configuration that wires a {@link VertexAiTextEmbeddingModel} from
	 * the project/location environment variables.
	 */
	@SpringBootConfiguration
	static class Config {

		@Bean
		public VertexAiEmbeddingConnectionDetails connectionDetails() {
			return VertexAiEmbeddingConnectionDetails.builder()
				.projectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
				.location(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
				.build();
		}

		@Bean
		public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails) {

			VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
				.model(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME)
				.build();

			return new VertexAiTextEmbeddingModel(connectionDetails, options);
		}

	}

}
